{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b86c2320-71ed-4223-9df7-0b9281cb652c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move to the parent directory (presumably the repository root, which holds the\n",
    "# `kg_rag` package imported next and the relative 'data/...' paths used below).\n",
    "# Guarded so the cell is idempotent: re-running it no longer walks one more level\n",
    "# up the directory tree.\n",
    "if not os.path.isdir('kg_rag'):\n",
    "    os.chdir('..')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e9dc80f-43a6-4d8d-9d99-343bc6515ff8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from kg_rag.utility import *\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b991006-9e91-4db1-9c11-62cbf1d9c356",
   "metadata": {},
   "source": [
    "## Choose the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ae38918-24e1-4a28-b4e5-461eda38002c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat model used for every KG-RAG query in this notebook.\n",
    "LLM_MODEL = \"gpt-4-32k\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db3c5056-15d6-4608-87c8-1e897dc4075e",
   "metadata": {},
   "source": [
    "## Configure KG-RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fdf4d8fd-2265-4237-ba85-06a3efbf8145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt and retrieval settings; `system_prompts` and `config_data` are provided by\n",
    "# the `from kg_rag.utility import *` cell.\n",
    "SYSTEM_PROMPT = system_prompts[\"KG_RAG_BASED_TEXT_GENERATION\"]\n",
    "CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
    "QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
    "QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
    "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
    "NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
    "# Cast like the other numeric settings above so the chat API always receives a\n",
    "# number, even if the config stores the temperature as a string.\n",
    "TEMPERATURE = float(config_data[\"LLM_TEMPERATURE\"])\n",
    "\n",
    "CHAT_MODEL_ID = LLM_MODEL\n",
    "# Passed as the last argument of retrieve_context (presumably toggles whether edge\n",
    "# evidence is included in the retrieved context - TODO confirm in kg_rag.utility).\n",
    "EDGE_EVIDENCE = True\n",
    "\n",
    "# Sent as `deployment_id` to the chat call; this notebook assumes the deployment is\n",
    "# named after the model.\n",
    "CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID\n",
    "\n",
    "# Load the resources that the run cells pass to retrieve_context.\n",
    "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
    "embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
    "node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "547cf664-8b48-4f19-a232-09a5b2fa4ffa",
   "metadata": {},
   "source": [
    "## Load test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "00fa2491-901e-44ea-8109-2a60b23771ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path is relative to the working directory set by the os.chdir call in the first cell.\n",
    "TEST_DATA_PATH = 'data/rag_comparison_data.csv'\n",
    "data = pd.read_csv(TEST_DATA_PATH)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c207c9-49be-449b-9b70-a92cdf8095d3",
   "metadata": {},
   "source": [
    "## Function for chat completion with token usage tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8ca41e38-79fb-4f68-aa16-db1785b6551f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):\n",
    "    \"\"\"Send one system + user exchange to the chat-completion endpoint.\n",
    "\n",
    "    Args:\n",
    "        instruction: Text sent as the user message.\n",
    "        system_prompt: Text sent as the system message.\n",
    "        chat_model_id: Value passed as ``model``.\n",
    "        chat_deployment_id: Value passed as ``deployment_id``.\n",
    "        temperature: Sampling temperature forwarded to the API.\n",
    "\n",
    "    Returns:\n",
    "        A ``(content, total_tokens)`` tuple: the message text of the first choice\n",
    "        and the total token count reported in the response's ``usage`` field.\n",
    "    \"\"\"\n",
    "    conversation = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": instruction},\n",
    "    ]\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=chat_model_id,\n",
    "        deployment_id=chat_deployment_id,\n",
    "        temperature=temperature,\n",
    "        messages=conversation,\n",
    "    )\n",
    "    first_choice = response['choices'][0]\n",
    "    answer_text = first_choice['message']['content']\n",
    "    tokens_used = response.usage.total_tokens\n",
    "    return answer_text, tokens_used\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b2bbab7-72f6-414b-bdd0-0eab4ed842f2",
   "metadata": {},
   "source": [
    "## Run on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "637671b2-a06c-4fe4-a7a6-855b0ba48fcd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [11:13,  6.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3min 37s, sys: 9.86 s, total: 3min 47s\n",
      "Wall time: 11min 13s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# For each original question: retrieve KG context, prepend it to the question,\n",
    "# query the chat model, and record the answer plus the total tokens it consumed.\n",
    "kg_rag_answer = []\n",
    "total_tokens_used = []\n",
    "\n",
    "# Iterate the column directly (no unused row index) and give tqdm the total so it\n",
    "# shows a real progress bar and ETA instead of a bare iteration count.\n",
    "for question in tqdm(data['question'], total=len(data), desc='KG-RAG'):\n",
    "    context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
    "    enriched_prompt = \"Context: \" + context + \"\\n\" + \"Question: \" + question\n",
    "    output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
    "    kg_rag_answer.append(output)\n",
    "    total_tokens_used.append(token_usage)\n",
    "\n",
    "data.loc[:, 'kg_rag_answer'] = kg_rag_answer\n",
    "data.loc[:, 'total_tokens_used'] = total_tokens_used\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e4b72c-c2a5-4b1a-8100-7ad831eb1401",
   "metadata": {},
   "source": [
    "## Run on perturbed test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8a042aa2-2366-4d49-a694-efd6d7b4616b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [09:49,  5.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3min 36s, sys: 9.04 s, total: 3min 45s\n",
      "Wall time: 9min 49s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Repeat the KG-RAG run on the perturbed version of each question. The previous\n",
    "# version of this cell read data['question'], i.e. it silently re-ran the original\n",
    "# questions and stored them under the *_perturbed columns.\n",
    "# NOTE(review): the column name is assumed from the `_perturbed` suffix used for the\n",
    "# output columns; confirm it against data/rag_comparison_data.csv.\n",
    "PERTURBED_QUESTION_COLUMN = 'question_perturbed'\n",
    "assert PERTURBED_QUESTION_COLUMN in data.columns, (\n",
    "    f\"{PERTURBED_QUESTION_COLUMN!r} not found in data; available columns: {list(data.columns)}\"\n",
    ")\n",
    "\n",
    "kg_rag_answer = []\n",
    "total_tokens_used = []\n",
    "\n",
    "for question in tqdm(data[PERTURBED_QUESTION_COLUMN], total=len(data), desc='KG-RAG (perturbed)'):\n",
    "    context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
    "    enriched_prompt = \"Context: \" + context + \"\\n\" + \"Question: \" + question\n",
    "    output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
    "    kg_rag_answer.append(output)\n",
    "    total_tokens_used.append(token_usage)\n",
    "\n",
    "data.loc[:, 'kg_rag_answer_perturbed'] = kg_rag_answer\n",
    "data.loc[:, 'total_tokens_used_perturbed'] = total_tokens_used\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c902260-1d9e-4a52-a377-f0c002c91e16",
   "metadata": {},
   "source": [
    "## Save the result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d510de56-dd39-4742-8a5a-9bb934690d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = 'data/results'\n",
    "output_csv = os.path.join(save_path, 'kg_rag_output.csv')\n",
    "\n",
    "# `data` now holds the input columns plus the answer and token-usage columns added above.\n",
    "os.makedirs(save_path, exist_ok=True)  # create the results folder if it is missing\n",
    "data.to_csv(output_csv, index=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
